{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Generating Crystal Structures with CrystaLLM\n",
        "This notebook is a companion to chapter 7 of the book \"Domain Specific LLMs in Action\" by Guglielmo Iozzia ([Manning Publications](https://www.manning.com/), 2024).  \n",
        "The code in this notebook generates and evaluates crystal structures using the [CrystaLLM](https://github.com/lantunes/CrystaLLM) model. It doesn't require hardware acceleration.  \n",
        "More details about the code can be found in the related chapter of the book."
      ],
      "metadata": {
        "id": "7gmR2OYoIUIU"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Clone the CrystaLLM repo."
      ],
      "metadata": {
        "id": "7o_RjB3sI9D3"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KtrXp7CNU3cF"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/lantunes/CrystaLLM.git"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Install the missing requirements. Only pymatgen (Python Materials Genomics, a robust open source library for materials analysis) and OmegaConf (a YAML-based hierarchical configuration system) are not preinstalled on the Colab VMs."
      ],
      "metadata": {
        "id": "fCNOirgoXaZw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install pymatgen==2023.3.23 omegaconf"
      ],
      "metadata": {
        "id": "mUL5qHm8XcNm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Add the CrystaLLM path to the Python path."
      ],
      "metadata": {
        "id": "qfYsx41OX4Mc"
      }
    },
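    {
      "cell_type": "code",
      "source": [
        "import sys\n",
        "import os\n",
        "\n",
        "# Make the CrystaLLM package importable in this kernel\n",
        "sys.path.append('/content/CrystaLLM')\n",
        "# Also expose it to the `!python` subprocesses; .get() avoids a KeyError if PYTHONPATH is unset\n",
        "os.environ[\"PYTHONPATH\"] = os.environ.get(\"PYTHONPATH\", \"\") + \":/content/CrystaLLM\""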
      ],
      "metadata": {
        "id": "E6OYlprLYCFj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%cd CrystaLLM"
      ],
      "metadata": {
        "id": "IGpJ3k6yjQlT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download a CrystaLLM pretrained model. The pretrained models aren't available on the Hugging Face Hub, so we have to use the provided ```download.py``` script."
      ],
      "metadata": {
        "id": "GC3tYzWFVzdt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!python bin/download.py crystallm_v1_small.tar.gz\n",
        "!tar xvf crystallm_v1_small.tar.gz"
      ],
      "metadata": {
        "id": "KzNPNOGgVGd4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Create a prompt for the generation process and save it to a file. Here the prompt is built for the cell composition Na2Cl2 and the space group P4/nmm."
      ],
      "metadata": {
        "id": "hvSa-RKLWAp8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!python bin/make_prompt_file.py Na2Cl2 sample_prompt.txt --spacegroup P4/nmm"
      ],
      "metadata": {
        "id": "mHQOH_XJV5q2"
      },
      "execution_count": null,
      "outputs": []
    },
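    {
      "cell_type": "markdown",
      "source": [
        "Generate CIF files by random sampling. The command below draws two samples (`num_samples=2`) with top-k sampling (`top_k=10`) on the CPU and, with `target=file`, writes them to files."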
      ],
      "metadata": {
        "id": "Yus9efFDWHof"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!python bin/sample.py \\\n",
        "out_dir=crystallm_v1_small \\\n",
        "start=FILE:sample_prompt.txt \\\n",
        "num_samples=2 \\\n",
        "top_k=10 \\\n",
        "max_new_tokens=3000 \\\n",
        "device=cpu \\\n",
        "target=file"
      ],
      "metadata": {
        "id": "FlvFVDsKWJaS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Post-process the generated raw CIF files."
      ],
      "metadata": {
        "id": "Z7XhLPA1Wggq"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!python bin/postprocess.py . colab_processed_cifs"
      ],
      "metadata": {
        "id": "Zh3eHGJ4Wj8N"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Alternatively, we can use Monte Carlo Tree Search (MCTS) decoding to generate CIF files from the given prompt."
      ],
      "metadata": {
        "id": "6LjpSrMmWkYh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!python bin/mcts.py \\\n",
        "out_dir=crystallm_v1_small \\\n",
        "device=cpu \\\n",
        "dtype=bfloat16 \\\n",
        "start=FILE:sample_prompt.txt \\\n",
        "tree_width=5 \\\n",
        "max_depth=2000 \\\n",
        "selector=puct \\\n",
        "c=1.0 \\\n",
        "num_simulations=1000 \\\n",
        "reward_k=2.0 \\\n",
        "scorer=random \\\n",
        "top_child_weight_cutoff=0.9999 \\\n",
        "bypass_only_child=True \\\n",
        "mcts_out_dir=colab_mcts_cifs"
      ],
      "metadata": {
        "id": "KMny0DmdWoPy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### CIF Files Evaluation"
      ],
      "metadata": {
        "id": "GULePac0myNN"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Pack the post-processed CIF files into a tar.gz archive, since the provided evaluation script requires its input in this format."
      ],
      "metadata": {
        "id": "BZVPpkktrEx0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!tar -czvf colab_processed_cifs.tar.gz ./colab_processed_cifs/"
      ],
      "metadata": {
        "id": "TViQn1Gprj-b"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Perform the evaluation and save the results to a CSV file."
      ],
      "metadata": {
        "id": "Cfa_jcMRqbPG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!python bin/evaluate_cifs.py colab_processed_cifs.tar.gz -o colab_processed_cifs.csv"
      ],
      "metadata": {
        "id": "5ye8mFSSqf9-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Optional: Prepare the Pretrained Model for Pushing to the Hugging Face Hub"
      ],
      "metadata": {
        "id": "qh3lKIOZRjbQ"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "This section showcases the steps to make a CrystaLLM pretrained model available through the Hugging Face Transformers API and to share it on the Hugging Face Hub. The code below is for educational purposes only: please refrain from sharing the original CrystaLLM pretrained model without consent from the authors. You can still fine-tune the original models further on your own dataset of CIF files, but please stay compliant with any update to the open source license of the CrystaLLM work before taking any further action. Thanks."
      ],
      "metadata": {
        "id": "4nS-4OQ9cH93"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Upgrade NumPy to the latest version."
      ],
      "metadata": {
        "id": "MdnyiVe9TKjD"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -U numpy"
      ],
      "metadata": {
        "id": "090dYd6rRp0h"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Create the configuration file for the model."
      ],
      "metadata": {
        "id": "mV6XPsswRvBL"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "crystall_small_config = \"\"\"{\n",
        "  \"bias\": true,\n",
        "  \"model_type\": \"gpt2\",\n",
        "  \"block_size\": 1024,\n",
        "  \"dropout\": 0.1,\n",
        "  \"n_embd\": 512,\n",
        "  \"n_head\": 8,\n",
        "  \"n_layer\": 8,\n",
        "  \"vocab_size\": 50257\n",
        "}\"\"\"\n",
        "\n",
        "# Open in write mode ('w') so that re-running this cell overwrites the file\n",
        "# instead of appending a second JSON object and making config.json invalid\n",
        "with open('/content/CrystaLLM/crystallm_v1_small/config.json', 'w') as f:\n",
        "    f.write(crystall_small_config)"
      ],
      "metadata": {
        "id": "Sh0fAFkWS1_v"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Extract the model parameters from the original `.pt` checkpoint and save them in `.bin` format."
      ],
      "metadata": {
        "id": "r_LB92ADdWJP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "# Load the checkpoint file\n",
        "checkpoint = torch.load(\"/content/CrystaLLM/crystallm_v1_small/ckpt.pt\", map_location=torch.device('cpu'))\n",
        "\n",
        "# Extract the model parameters\n",
        "params = checkpoint[\"model\"]\n",
        "\n",
        "# Save the parameters to a .bin file\n",
        "torch.save(params, \"/content/CrystaLLM/crystallm_v1_small/pytorch_model.bin\")"
      ],
      "metadata": {
        "id": "dMn2xLppYKAx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Load the configuration file and the converted checkpoint through the Transformers API, using the generic `AutoConfig` and `AutoModelForCausalLM` classes."
      ],
      "metadata": {
        "id": "3njTEYBodk0C"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoConfig, AutoModelForCausalLM\n",
        "\n",
        "config = AutoConfig.from_pretrained('/content/CrystaLLM/crystallm_v1_small/config.json')\n",
        "transformer_model = AutoModelForCausalLM.from_pretrained('/content/CrystaLLM/crystallm_v1_small', config=config)"
      ],
      "metadata": {
        "id": "DbxdhH9CR8BJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Save the model using the Transformers API. The `save_pretrained` method invoked in the code cell below also generates any other required accessory files."
      ],
      "metadata": {
        "id": "p0bsBM7kdy8H"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "transformer_model.save_pretrained(\"/content/CrystaLLM/crystallm_v1_small_hf\")"
      ],
      "metadata": {
        "id": "MWsjb6bbZsMl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The model is now ready to be uploaded to the Hugging Face Hub (assuming you have a valid Hugging Face account). Please read the statement at the beginning of this section about permission to share the checkpoints through the Hub."
      ],
      "metadata": {
        "id": "Hhqc7diGeBK6"
      }
    }
  ]
}